Practice 2. Recurrent Neural Networks¶
- Alejandro Dopico Castro (alejandro.dopico2@udc.es).
- Ana Xiangning Pereira Ezquerro (ana.ezquerro@udc.es).
The following notebook contains execution examples of the recurrent neural architecture proposed for the Walmart dataset. The submitted Python scripts include auxiliary code to improve the readability of the code cells.
- data.py: Includes the WalmartDataset class to instantiate each dataset.
- model.py: Includes the WalmartModel class to instantiate a model with fixed hyperparameters, and the DenormalizedMAE metric to use in the Keras fit() method.
- plots.py: Includes auxiliary functions to display the time series performance of model predictions.
Note: To properly visualize and interact with the Plotly graphs, we recommend using the walmart.html file.
from data import *
from plots import *
from model import *
from utils import *
from keras.layers import *
from keras.models import Sequential, Model
from keras.optimizers import Adam, Optimizer, RMSprop
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.regularizers import Regularizer, L1, L2, L1L2
from tensorflow.data import Dataset
from itertools import product
from collections import OrderedDict
import plotly.offline as pyo
pyo.init_notebook_mode()
# patch __str__ so regularizers and optimizers print concise, human-readable names
Regularizer.__str__ = lambda x: str(x.__class__.__name__)
Optimizer.__str__ = lambda x: str(x.__class__.__name__) + f'({float(x.learning_rate.numpy()):1.0e})'
# global parameters
TEST_RATIO = 0.2
VAL_RATIO = 0.15
BATCH_SIZE = 200
# load data
data = WalmartDataset.load('Walmart.csv')
train, val, test = data.split(VAL_RATIO, TEST_RATIO)
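Since the data form a time series, the split must be chronological: the earliest observations go to training and the most recent to testing. A minimal sketch of how such ratio-based partitioning could work (chronological_split is a hypothetical helper, not the actual WalmartDataset.split implementation):

```python
def chronological_split(n, val_ratio, test_ratio):
    # Partition the index range [0, n) into contiguous train/val/test
    # spans, preserving temporal order (latest samples become the test set).
    n_test = int(n * test_ratio)
    n_val = int(n * val_ratio)
    n_train = n - n_val - n_test
    idx = list(range(n))
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]

train_idx, val_idx, test_idx = chronological_split(100, 0.15, 0.2)
print(len(train_idx), len(val_idx), len(test_idx))  # 65 15 20
```

Shuffling before splitting would leak future information into training, which is why the spans stay contiguous.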
Recurrent Neural Model¶
To model the temporal relations in the stream data, our neural architecture is a recurrent encoder ($\mathcal{E}$) with $\ell$ hidden layers of dimension $d_h$ that projects the input sequence $\mathbf{X}\in\mathbb{R}^{S\times d_x}$ to a time-contextualized sequence of embeddings $\mathbf{H} = \mathcal{E}(\mathbf{X}) \in \mathbb{R}^{S\times d_h}$, where $d_x$ and $d_h$ denote the number of input features and the hidden dimension of the model, respectively, and $S$ denotes the sequence length. The result $\mathbf{H}$ is passed through a final recurrent (LSTM-based) layer, whose final state $\tilde{\mathbf{h}}\in\mathbb{R}^{d_h}$ is used as a summary of the sample. This representation is then fed to a feed-forward decoder composed of $\varphi$ dense layers, the last of which is constrained to a linear activation to predict the output value $\hat{y}$ (the number of sales expected for timestep $t+2$).
In this section we explore three possible values for the hyperparameter $S$ to assess the impact of past observations on the sales modelling, keeping the remaining hyperparameters (number of layers, model dimension, activations, etc.) at their default values. The default configuration (baseline) uses an encoder of 2 stacked LSTMs with a decoder of 2 feed-forward layers. The only regularization method used is dropout (10%). This naive network can easily be improved, but we decided to start with the simplest architecture and incrementally increase the model's complexity while controlling overfitting with regularization methods.
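The role of $S$ can be illustrated with a small windowing sketch: each training sample is the sequence of the last $S$ observations, and the target is the value two timesteps ahead. Note that larger $S$ yields fewer usable windows from the same series (make_windows is a hypothetical helper; the actual preprocessing lives in data.py):

```python
def make_windows(series, S, horizon=2):
    # Build (input window, target) pairs: the window covers timesteps
    # [t - S + 1, t] and the target is the value at t + horizon.
    X, y = [], []
    for t in range(S - 1, len(series) - horizon):
        X.append(series[t - S + 1:t + 1])
        y.append(series[t + horizon])
    return X, y

X, y = make_windows(list(range(10)), S=3, horizon=2)
print(len(X), X[0], y[0])  # 6 [0, 1, 2] 4
```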
# S = 2
model2 = WalmartModel(2, hidden_size=10)
model2.train(train, val, 'results/walmart2.weights.h5', Adam(1e-3), batch_size=BATCH_SIZE)
model2.evaluate(test)
Epoch 1/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 4s 45ms/step - dmae: 542466.0625 - loss: 1.2749 - mae: 0.9731 - val_dmae: 482828.5000 - val_loss: 1.1356 - val_mae: 0.8661
...
Epoch 112/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 104613.5078 - loss: 0.0824 - mae: 0.1877 - val_dmae: 136088.3594 - val_loss: 0.1257 - val_mae: 0.2441
Epoch 112: early stopping
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 16ms/step - dmae: 67069.3906 - loss: 0.0270 - mae: 0.1203
[0.02606853097677231, 64549.47265625, 0.11579488962888718]
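For reference on the metrics above: loss and mae are computed on the normalized targets, while dmae reports the error in the original sales scale. Assuming the targets were normalized by a single scale factor, a DenormalizedMAE-style metric would amount to rescaling the normalized absolute error (a sketch with a made-up scale, not the actual metric implementation in model.py):

```python
def denormalized_mae(y_true_norm, y_pred_norm, scale):
    # Mean absolute error on normalized values, rescaled back to the
    # original units (assumes targets were divided by `scale`).
    errors = [abs(t - p) for t, p in zip(y_true_norm, y_pred_norm)]
    return scale * sum(errors) / len(errors)

# hypothetical scale in weekly-sales units
print(denormalized_mae([1.0, 2.0], [0.5, 2.5], scale=500_000))  # 250000.0
```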
# S = 3
model3 = WalmartModel(3, hidden_size=10)
model3.train(train, val, 'results/walmart3.weights.h5', Adam(1e-3), batch_size=BATCH_SIZE)
model3.evaluate(test)
Epoch 1/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 4s 43ms/step - dmae: 540936.6250 - loss: 1.2712 - mae: 0.9704 - val_dmae: 446060.9375 - val_loss: 0.8851 - val_mae: 0.8002
...
Epoch 73/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104421.3203 - loss: 0.0906 - mae: 0.1873 - val_dmae: 124994.4219 - val_loss: 0.1073 - val_mae: 0.2242
Epoch 74/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103898.9297 - loss: 0.0898 - mae: 0.1864 - val_dmae: 126199.6406 - val_loss: 0.1087 - val_mae: 0.2264 Epoch 75/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105222.2500 - loss: 0.0911 - mae: 0.1888 - val_dmae: 125876.5703 - val_loss: 0.1080 - val_mae: 0.2258 Epoch 76/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104338.0625 - loss: 0.0906 - mae: 0.1872 - val_dmae: 126101.1406 - val_loss: 0.1084 - val_mae: 0.2262 Epoch 77/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 102422.9766 - loss: 0.0876 - mae: 0.1837 - val_dmae: 124594.4141 - val_loss: 0.1065 - val_mae: 0.2235 Epoch 78/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103584.4844 - loss: 0.0885 - mae: 0.1858 - val_dmae: 124104.4922 - val_loss: 0.1053 - val_mae: 0.2226 Epoch 79/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104298.5938 - loss: 0.0881 - mae: 0.1871 - val_dmae: 124902.7422 - val_loss: 0.1058 - val_mae: 0.2241 Epoch 80/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 103117.6797 - loss: 0.0886 - mae: 0.1850 - val_dmae: 124353.8672 - val_loss: 0.1056 - val_mae: 0.2231 Epoch 81/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104999.2266 - loss: 0.0891 - mae: 0.1884 - val_dmae: 124839.6094 - val_loss: 0.1061 - val_mae: 0.2239 Epoch 82/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 102493.0156 - loss: 0.0851 - mae: 0.1839 - val_dmae: 125998.2734 - val_loss: 0.1073 - val_mae: 0.2260 Epoch 83/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103112.1172 - loss: 0.0856 - mae: 0.1850 - val_dmae: 124823.8359 - val_loss: 0.1056 - val_mae: 0.2239 Epoch 83: early stopping 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 16ms/step - dmae: 72043.2344 - loss: 0.0310 - mae: 0.1292
[0.02758491039276123, 66173.328125, 0.11870791763067245]
# S = 4
model4 = WalmartModel(4, hidden_size=10)
model4.train(train, val, 'results/walmart4.weights.h5', Adam(1e-3), batch_size=BATCH_SIZE)
model4.evaluate(test)
Epoch 1/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 4s 44ms/step - dmae: 541183.5000 - loss: 1.2739 - mae: 0.9708 - val_dmae: 445947.1250 - val_loss: 0.8849 - val_mae: 0.8000 Epoch 2/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 523680.4062 - loss: 1.1994 - mae: 0.9394 - val_dmae: 389234.1562 - val_loss: 0.6675 - val_mae: 0.6982 Epoch 3/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 410993.5625 - loss: 0.7863 - mae: 0.7373 - val_dmae: 204479.6875 - val_loss: 0.2104 - val_mae: 0.3668 Epoch 4/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 140061.5156 - loss: 0.1481 - mae: 0.2513 - val_dmae: 179813.2812 - val_loss: 0.1822 - val_mae: 0.3226 Epoch 5/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 123100.4844 - loss: 0.1321 - mae: 0.2208 - val_dmae: 177829.4844 - val_loss: 0.1793 - val_mae: 0.3190 Epoch 6/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 121718.7812 - loss: 0.1309 - mae: 0.2184 - val_dmae: 176610.0781 - val_loss: 0.1757 - val_mae: 0.3168 Epoch 7/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 122855.4531 - loss: 0.1315 - mae: 0.2204 - val_dmae: 175641.4844 - val_loss: 0.1735 - val_mae: 0.3151 Epoch 8/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 121785.8672 - loss: 0.1301 - mae: 0.2185 - val_dmae: 175341.3594 - val_loss: 0.1713 - val_mae: 0.3145 Epoch 9/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 120890.7188 - loss: 0.1289 - mae: 0.2169 - val_dmae: 174643.0781 - val_loss: 0.1695 - val_mae: 0.3133 Epoch 10/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 121292.8984 - loss: 0.1299 - mae: 0.2176 - val_dmae: 174020.7812 - val_loss: 0.1675 - val_mae: 0.3122 Epoch 11/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 120828.5781 - loss: 0.1299 - mae: 0.2168 - val_dmae: 173717.3281 - val_loss: 0.1668 - val_mae: 0.3116 Epoch 12/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 120860.0391 - loss: 0.1286 - mae: 0.2168 - val_dmae: 173037.5781 - val_loss: 0.1651 - val_mae: 0.3104 Epoch 13/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 
1s 32ms/step - dmae: 118227.3750 - loss: 0.1263 - mae: 0.2121 - val_dmae: 172579.0781 - val_loss: 0.1642 - val_mae: 0.3096 Epoch 14/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 120309.1562 - loss: 0.1275 - mae: 0.2158 - val_dmae: 172116.3594 - val_loss: 0.1632 - val_mae: 0.3088 Epoch 15/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 119245.2422 - loss: 0.1274 - mae: 0.2139 - val_dmae: 171584.2969 - val_loss: 0.1626 - val_mae: 0.3078 Epoch 16/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 119667.2344 - loss: 0.1261 - mae: 0.2147 - val_dmae: 170852.2812 - val_loss: 0.1609 - val_mae: 0.3065 Epoch 17/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118587.4531 - loss: 0.1247 - mae: 0.2127 - val_dmae: 170126.2031 - val_loss: 0.1599 - val_mae: 0.3052 Epoch 18/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 119761.6016 - loss: 0.1256 - mae: 0.2148 - val_dmae: 169628.2500 - val_loss: 0.1588 - val_mae: 0.3043 Epoch 19/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118999.5859 - loss: 0.1254 - mae: 0.2135 - val_dmae: 169241.3125 - val_loss: 0.1578 - val_mae: 0.3036 Epoch 20/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 120278.4688 - loss: 0.1257 - mae: 0.2158 - val_dmae: 168276.1250 - val_loss: 0.1565 - val_mae: 0.3019 Epoch 21/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 119082.6094 - loss: 0.1239 - mae: 0.2136 - val_dmae: 167657.5312 - val_loss: 0.1552 - val_mae: 0.3008 Epoch 22/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118829.9531 - loss: 0.1235 - mae: 0.2132 - val_dmae: 166965.6875 - val_loss: 0.1547 - val_mae: 0.2995 Epoch 23/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 117828.6094 - loss: 0.1223 - mae: 0.2114 - val_dmae: 166416.9531 - val_loss: 0.1534 - val_mae: 0.2985 Epoch 24/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118497.5547 - loss: 0.1225 - mae: 0.2126 - val_dmae: 165379.1250 - val_loss: 0.1522 - val_mae: 0.2967 Epoch 25/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 
117032.6328 - loss: 0.1208 - mae: 0.2099 - val_dmae: 164990.1406 - val_loss: 0.1516 - val_mae: 0.2960 Epoch 26/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118040.7656 - loss: 0.1230 - mae: 0.2118 - val_dmae: 163983.4375 - val_loss: 0.1509 - val_mae: 0.2942 Epoch 27/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 115586.5312 - loss: 0.1188 - mae: 0.2073 - val_dmae: 163506.5000 - val_loss: 0.1496 - val_mae: 0.2933 Epoch 28/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116925.1016 - loss: 0.1226 - mae: 0.2098 - val_dmae: 162691.6562 - val_loss: 0.1489 - val_mae: 0.2919 Epoch 29/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 117944.3047 - loss: 0.1194 - mae: 0.2116 - val_dmae: 161717.6562 - val_loss: 0.1479 - val_mae: 0.2901 Epoch 30/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116896.6641 - loss: 0.1212 - mae: 0.2097 - val_dmae: 160816.7969 - val_loss: 0.1455 - val_mae: 0.2885 Epoch 31/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116105.2422 - loss: 0.1179 - mae: 0.2083 - val_dmae: 160315.1094 - val_loss: 0.1462 - val_mae: 0.2876 Epoch 32/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116258.7734 - loss: 0.1181 - mae: 0.2086 - val_dmae: 159055.1719 - val_loss: 0.1445 - val_mae: 0.2853 Epoch 33/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 115883.6797 - loss: 0.1173 - mae: 0.2079 - val_dmae: 158377.6094 - val_loss: 0.1431 - val_mae: 0.2841 Epoch 34/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 115552.8828 - loss: 0.1167 - mae: 0.2073 - val_dmae: 157624.4062 - val_loss: 0.1428 - val_mae: 0.2828 Epoch 35/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 114920.1016 - loss: 0.1162 - mae: 0.2062 - val_dmae: 156869.3438 - val_loss: 0.1410 - val_mae: 0.2814 Epoch 36/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 115326.1328 - loss: 0.1162 - mae: 0.2069 - val_dmae: 155840.2656 - val_loss: 0.1408 - val_mae: 0.2796 Epoch 37/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 114686.2031 - loss: 0.1160 - 
mae: 0.2057 - val_dmae: 154683.8438 - val_loss: 0.1384 - val_mae: 0.2775 Epoch 38/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 115514.4844 - loss: 0.1158 - mae: 0.2072 - val_dmae: 153769.5781 - val_loss: 0.1385 - val_mae: 0.2758 Epoch 39/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 114226.4062 - loss: 0.1153 - mae: 0.2049 - val_dmae: 153012.1719 - val_loss: 0.1362 - val_mae: 0.2745 Epoch 40/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113979.2031 - loss: 0.1135 - mae: 0.2045 - val_dmae: 152202.2812 - val_loss: 0.1354 - val_mae: 0.2730 Epoch 41/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113728.9062 - loss: 0.1115 - mae: 0.2040 - val_dmae: 151430.1094 - val_loss: 0.1353 - val_mae: 0.2716 Epoch 42/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 113138.7812 - loss: 0.1108 - mae: 0.2030 - val_dmae: 150110.1406 - val_loss: 0.1350 - val_mae: 0.2693 Epoch 43/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112732.0078 - loss: 0.1121 - mae: 0.2022 - val_dmae: 149602.0312 - val_loss: 0.1334 - val_mae: 0.2684 Epoch 44/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112953.9062 - loss: 0.1103 - mae: 0.2026 - val_dmae: 148456.2188 - val_loss: 0.1319 - val_mae: 0.2663 Epoch 45/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112995.1875 - loss: 0.1098 - mae: 0.2027 - val_dmae: 147468.7188 - val_loss: 0.1319 - val_mae: 0.2645 Epoch 46/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112818.1875 - loss: 0.1119 - mae: 0.2024 - val_dmae: 146326.2031 - val_loss: 0.1303 - val_mae: 0.2625 Epoch 47/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111875.6641 - loss: 0.1095 - mae: 0.2007 - val_dmae: 145698.6250 - val_loss: 0.1292 - val_mae: 0.2614 Epoch 48/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110506.3906 - loss: 0.1053 - mae: 0.1982 - val_dmae: 144169.7656 - val_loss: 0.1285 - val_mae: 0.2586 Epoch 49/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112231.7969 - loss: 0.1075 - mae: 0.2013 - val_dmae: 
143225.8281 - val_loss: 0.1270 - val_mae: 0.2569 Epoch 50/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110201.1484 - loss: 0.1069 - mae: 0.1977 - val_dmae: 142438.1719 - val_loss: 0.1259 - val_mae: 0.2555 Epoch 51/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 110309.5938 - loss: 0.1065 - mae: 0.1979 - val_dmae: 141394.1406 - val_loss: 0.1253 - val_mae: 0.2536 Epoch 52/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108985.8203 - loss: 0.1044 - mae: 0.1955 - val_dmae: 140468.4688 - val_loss: 0.1243 - val_mae: 0.2520 Epoch 53/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111539.7656 - loss: 0.1061 - mae: 0.2001 - val_dmae: 139558.6094 - val_loss: 0.1235 - val_mae: 0.2504 Epoch 54/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109181.4688 - loss: 0.1025 - mae: 0.1959 - val_dmae: 138296.8281 - val_loss: 0.1217 - val_mae: 0.2481 Epoch 55/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109968.4609 - loss: 0.1032 - mae: 0.1973 - val_dmae: 137570.2344 - val_loss: 0.1217 - val_mae: 0.2468 Epoch 56/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 107895.3750 - loss: 0.0994 - mae: 0.1936 - val_dmae: 136646.6875 - val_loss: 0.1208 - val_mae: 0.2451 Epoch 57/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107892.7422 - loss: 0.1010 - mae: 0.1935 - val_dmae: 136022.4531 - val_loss: 0.1198 - val_mae: 0.2440 Epoch 58/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108493.5859 - loss: 0.1007 - mae: 0.1946 - val_dmae: 134835.5938 - val_loss: 0.1189 - val_mae: 0.2419 Epoch 59/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109186.6172 - loss: 0.1018 - mae: 0.1959 - val_dmae: 134169.0625 - val_loss: 0.1181 - val_mae: 0.2407 Epoch 60/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 109062.6641 - loss: 0.1026 - mae: 0.1956 - val_dmae: 133646.3125 - val_loss: 0.1180 - val_mae: 0.2397 Epoch 61/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107113.2734 - loss: 0.0979 - mae: 0.1921 - val_dmae: 132571.2188 - val_loss: 0.1162 
- val_mae: 0.2378 Epoch 62/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105384.2266 - loss: 0.0981 - mae: 0.1890 - val_dmae: 131826.2031 - val_loss: 0.1162 - val_mae: 0.2365 Epoch 63/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106014.0703 - loss: 0.0974 - mae: 0.1902 - val_dmae: 131224.7031 - val_loss: 0.1149 - val_mae: 0.2354 Epoch 64/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106483.6250 - loss: 0.0976 - mae: 0.1910 - val_dmae: 130980.0000 - val_loss: 0.1144 - val_mae: 0.2350 Epoch 65/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 105500.1406 - loss: 0.0961 - mae: 0.1893 - val_dmae: 129380.7969 - val_loss: 0.1125 - val_mae: 0.2321 Epoch 66/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107148.6094 - loss: 0.0964 - mae: 0.1922 - val_dmae: 129163.6406 - val_loss: 0.1125 - val_mae: 0.2317 Epoch 67/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106408.5000 - loss: 0.0970 - mae: 0.1909 - val_dmae: 129165.8047 - val_loss: 0.1120 - val_mae: 0.2317 Epoch 68/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104381.0312 - loss: 0.0939 - mae: 0.1872 - val_dmae: 127756.3828 - val_loss: 0.1108 - val_mae: 0.2292 Epoch 69/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105224.0547 - loss: 0.0935 - mae: 0.1888 - val_dmae: 126828.8438 - val_loss: 0.1095 - val_mae: 0.2275 Epoch 70/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104486.1484 - loss: 0.0930 - mae: 0.1874 - val_dmae: 126969.3359 - val_loss: 0.1094 - val_mae: 0.2278 Epoch 71/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 103941.8906 - loss: 0.0921 - mae: 0.1865 - val_dmae: 126879.2422 - val_loss: 0.1091 - val_mae: 0.2276 Epoch 72/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105350.7734 - loss: 0.0933 - mae: 0.1890 - val_dmae: 126084.0000 - val_loss: 0.1085 - val_mae: 0.2262 Epoch 73/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104282.8750 - loss: 0.0908 - mae: 0.1871 - val_dmae: 126411.7500 - val_loss: 0.1079 - val_mae: 0.2268 Epoch 74/2000 
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 104767.8828 - loss: 0.0938 - mae: 0.1879 - val_dmae: 124989.8672 - val_loss: 0.1067 - val_mae: 0.2242 Epoch 75/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104791.2188 - loss: 0.0916 - mae: 0.1880 - val_dmae: 124792.5938 - val_loss: 0.1059 - val_mae: 0.2239 Epoch 76/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104388.0078 - loss: 0.0911 - mae: 0.1873 - val_dmae: 125464.5000 - val_loss: 0.1064 - val_mae: 0.2251 Epoch 77/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104210.6797 - loss: 0.0914 - mae: 0.1869 - val_dmae: 126268.2969 - val_loss: 0.1070 - val_mae: 0.2265 Epoch 78/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103442.5859 - loss: 0.0917 - mae: 0.1856 - val_dmae: 125105.3750 - val_loss: 0.1058 - val_mae: 0.2244 Epoch 79/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103514.8906 - loss: 0.0897 - mae: 0.1857 - val_dmae: 123832.4219 - val_loss: 0.1045 - val_mae: 0.2221 Epoch 80/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103531.8594 - loss: 0.0920 - mae: 0.1857 - val_dmae: 123028.3438 - val_loss: 0.1036 - val_mae: 0.2207 Epoch 81/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103358.8594 - loss: 0.0879 - mae: 0.1854 - val_dmae: 123414.1641 - val_loss: 0.1035 - val_mae: 0.2214 Epoch 82/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103729.3047 - loss: 0.0902 - mae: 0.1861 - val_dmae: 122685.9375 - val_loss: 0.1031 - val_mae: 0.2201 Epoch 83/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103041.3906 - loss: 0.0902 - mae: 0.1848 - val_dmae: 123704.9766 - val_loss: 0.1038 - val_mae: 0.2219 Epoch 84/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 102791.0625 - loss: 0.0872 - mae: 0.1844 - val_dmae: 125045.1719 - val_loss: 0.1050 - val_mae: 0.2243 Epoch 85/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103251.6406 - loss: 0.0866 - mae: 0.1852 - val_dmae: 124456.6250 - val_loss: 0.1044 - val_mae: 0.2233 Epoch 86/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 
33ms/step - dmae: 101863.6719 - loss: 0.0857 - mae: 0.1827 - val_dmae: 123422.0469 - val_loss: 0.1035 - val_mae: 0.2214 Epoch 87/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104849.0078 - loss: 0.0879 - mae: 0.1881 - val_dmae: 124842.6250 - val_loss: 0.1044 - val_mae: 0.2240 Epoch 87: early stopping 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 16ms/step - dmae: 78321.1094 - loss: 0.0377 - mae: 0.1405
[0.03053026646375656, 68176.4375, 0.12230127304792404]
The baseline network achieves a test denormalized MAE of roughly 64k, 66k and 68k with a sequence length $S$ of 2, 3 and 4, respectively. In theory, giving the model more input information (e.g. the model with $S=4$ sees more past data than the first model with $S=2$) should yield results at least as good as those of simpler input representations. In practice, optimizing networks whose inputs carry a lot of noise (data with no information about the target) and no regularization is extremely hard, and would require a considerable amount of data to ensure the generalization of the model. In our case, the baseline results suggest that feeding more than $S=2$ past timesteps does not help the network predict the target outcome, so sequences of length 2 should be enough. For the next models we fixed this hyperparameter to $S=2$.
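To make the effect of $S$ concrete, the windowing setup can be sketched as follows. This is a minimal sketch assuming a simple sliding window with the $t+2$ prediction horizon described above; the actual construction in data.py (the WalmartDataset class) may differ in details:

```python
import numpy as np

def make_windows(series, S, horizon=2):
    # Build (window, target) pairs: S consecutive past observations
    # predict the value `horizon` steps after the last one in the window.
    X, y = [], []
    for t in range(len(series) - S - horizon + 1):
        X.append(series[t:t + S])
        y.append(series[t + S + horizon - 1])
    return np.array(X), np.array(y)

series = np.arange(10, dtype=float)  # toy series, not the Walmart data
X2, y2 = make_windows(series, S=2)   # X2.shape == (7, 2)
X4, y4 = make_windows(series, S=4)   # X4.shape == (5, 4)
```

Note that a longer window also yields fewer training samples from the same series, which compounds the data-scarcity issue mentioned above.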
plot_series(model2, [train, val, test], title='Prediction with S=2').show()
plot_series(model3, [train, val, test], title='Prediction with S=3').show()
plot_series(model4, [train, val, test], title='Prediction with S=4').show()
To better show the performance disparity between the models, we plotted the absolute error of the three models in a single figure. Note that the model with $S=2$ incurs a higher error (especially on outlier observations), while the model with $S=4$ stays closest to the zero line, indicating a lower absolute error.
plot_errors([model2, model3, model4], [train, val, test])
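The quantity drawn by plot_errors can be illustrated with a minimal sketch (the arrays below are hypothetical, not taken from the Walmart data): for each model, the curve is the per-timestep absolute error $|y - \hat{y}|$, so a curve hugging the zero line means better predictions.

```python
import numpy as np

# Hypothetical ground truth and per-model predictions (illustrative only)
y_true = np.array([100.0, 120.0, 90.0, 110.0])
preds = {
    'S=2': np.array([105.0, 115.0, 97.0, 104.0]),
    'S=4': np.array([101.0, 119.0, 92.0, 109.0]),
}

# Per-timestep absolute error: the curve drawn for each model
abs_err = {name: np.abs(y_true - p) for name, p in preds.items()}
```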
Increasing the complexity of the model¶
Once we obtained our baseline results, we increased the complexity of the model by modifying the hyperparameters of the network (see model.py).
- base_layer: The recurrent base cell of the encoder $\mathcal{E}$. Two options are available: the LSTM and the GRU. Although the LSTM layer is more popular than the GRU (especially in NLP tasks), the GRU has advantages over the LSTM (e.g. it has fewer parameters) and is still used in other DL applications. By default, each base layer is an LSTM.
- num_encoder_layers ($\ell$): Number of layers in the encoder $\mathcal{E}$.
- num_decoder_layers ($\varphi$): Number of layers in the decoder.
- hidden_size ($d_h$): Hidden dimension of the encoder $\mathcal{E}$.
- regularizer: Kernel and bias regularizer of the hidden layers. By default there is no regularization.
- initializer: Weight initialization. All biases are initialized to zero. By default, kernels are initialized from a random normal distribution.
- bidirectional: Whether to process the input sequence both left-to-right and right-to-left, or only left-to-right. By default, the processing is bidirectional, so the left-to-right and right-to-left information is concatenated to produce a unique contextualization of each timestep observation.
- dropout: Dropout rate applied in the latent space of the neural architecture.
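The remark that the GRU has fewer parameters than the LSTM can be checked with a back-of-envelope count using the standard per-layer formulas (the feature count d_x = 7 below is purely illustrative, not taken from the Walmart data):

```python
def lstm_params(d_x, d_h):
    # 4 gates, each with an input kernel (d_h*d_x), a recurrent kernel
    # (d_h*d_h) and a bias vector (d_h)
    return 4 * (d_h * d_x + d_h * d_h + d_h)

def gru_params(d_x, d_h, reset_after=True):
    # 3 gates; Keras' default reset_after=True keeps two bias vectors per gate
    bias_terms = 2 if reset_after else 1
    return 3 * (d_h * d_x + d_h * d_h + bias_terms * d_h)

print(lstm_params(7, 50))  # 11600
print(gru_params(7, 50))   # 8850: roughly a quarter fewer parameters
```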
The next code cell increases the dimensionality of the model to $d_h=50$ and uses 3 layers in the encoder (keeping 2 layers in the decoder). We tested two different architectures: the first one again uses the LSTM cell, while the second replaces the LSTM layer with the GRU cell.
model_lstm = WalmartModel(2, hidden_size=50, num_encoder_layers=3)
model_lstm.train(train, val, 'results/walmart_lstm.weights.h5', Adam(1e-3), batch_size=BATCH_SIZE)
model_lstm.evaluate(test)
Epoch 1/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 7s 55ms/step - dmae: 541028.4375 - loss: 1.2628 - mae: 0.9705 - val_dmae: 425864.0312 - val_loss: 0.8837 - val_mae: 0.7640 Epoch 2/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 356509.9062 - loss: 0.6480 - mae: 0.6395 - val_dmae: 177930.2500 - val_loss: 0.2288 - val_mae: 0.3192 Epoch 3/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 116861.7734 - loss: 0.1174 - mae: 0.2096 - val_dmae: 167336.6250 - val_loss: 0.2143 - val_mae: 0.3002 Epoch 4/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 114137.8984 - loss: 0.1138 - mae: 0.2048 - val_dmae: 168187.6406 - val_loss: 0.2091 - val_mae: 0.3017 Epoch 5/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 114569.1953 - loss: 0.1126 - mae: 0.2055 - val_dmae: 168283.5312 - val_loss: 0.2067 - val_mae: 0.3019 Epoch 6/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 116331.8828 - loss: 0.1143 - mae: 0.2087 - val_dmae: 168518.9375 - val_loss: 0.2040 - val_mae: 0.3023 Epoch 7/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 118146.3594 - loss: 0.1152 - mae: 0.2119 - val_dmae: 167939.1719 - val_loss: 0.2013 - val_mae: 0.3013 Epoch 8/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 116887.8828 - loss: 0.1115 - mae: 0.2097 - val_dmae: 167847.4062 - val_loss: 0.1986 - val_mae: 0.3011 Epoch 8: early stopping 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 16ms/step - dmae: 78822.9453 - loss: 0.0351 - mae: 0.1414
[0.03196907415986061, 73991.0546875, 0.1327320635318756]
model_gru = WalmartModel(2, base_layer=GRU, hidden_size=50, num_encoder_layers=3)
model_gru.train(train, val, 'results/walmart_gru.weights.h5', Adam(1e-3), batch_size=BATCH_SIZE)
model_gru.evaluate(test)
Epoch 1/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 7s 53ms/step - dmae: 525163.6875 - loss: 1.2022 - mae: 0.9421 - val_dmae: 189033.2812 - val_loss: 0.2542 - val_mae: 0.3391 Epoch 2/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 171360.1094 - loss: 0.1827 - mae: 0.3074 - val_dmae: 190236.4844 - val_loss: 0.2677 - val_mae: 0.3413 Epoch 3/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 109318.0156 - loss: 0.1113 - mae: 0.1961 - val_dmae: 168021.2969 - val_loss: 0.2158 - val_mae: 0.3014 Epoch 4/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 109797.6641 - loss: 0.1098 - mae: 0.1970 - val_dmae: 165004.7188 - val_loss: 0.2051 - val_mae: 0.2960 Epoch 5/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 109416.9141 - loss: 0.1070 - mae: 0.1963 - val_dmae: 164106.9375 - val_loss: 0.1972 - val_mae: 0.2944 Epoch 6/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 110591.1719 - loss: 0.1061 - mae: 0.1984 - val_dmae: 163802.1250 - val_loss: 0.1911 - val_mae: 0.2938 Epoch 7/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 112792.9453 - loss: 0.1058 - mae: 0.2023 - val_dmae: 163523.4375 - val_loss: 0.1854 - val_mae: 0.2933 Epoch 8/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 111071.8203 - loss: 0.1022 - mae: 0.1993 - val_dmae: 162903.8750 - val_loss: 0.1811 - val_mae: 0.2922 Epoch 9/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 112165.5625 - loss: 0.1027 - mae: 0.2012 - val_dmae: 162133.5625 - val_loss: 0.1781 - val_mae: 0.2909 Epoch 10/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 112104.1328 - loss: 0.1014 - mae: 0.2011 - val_dmae: 161738.9062 - val_loss: 0.1760 - val_mae: 0.2901 Epoch 11/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 112796.5078 - loss: 0.1004 - mae: 0.2023 - val_dmae: 159849.2812 - val_loss: 0.1715 - val_mae: 0.2868 Epoch 12/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 111312.7969 - loss: 0.0995 - mae: 0.1997 - val_dmae: 159882.1406 - val_loss: 0.1706 - val_mae: 0.2868 Epoch 13/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 
2s 35ms/step - dmae: 111637.8594 - loss: 0.0989 - mae: 0.2003 - val_dmae: 159171.9844 - val_loss: 0.1684 - val_mae: 0.2855 Epoch 14/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 110569.5859 - loss: 0.0970 - mae: 0.1984 - val_dmae: 157746.8281 - val_loss: 0.1655 - val_mae: 0.2830 Epoch 15/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 109712.0547 - loss: 0.0950 - mae: 0.1968 - val_dmae: 157215.9844 - val_loss: 0.1639 - val_mae: 0.2820 Epoch 16/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 109980.0234 - loss: 0.0952 - mae: 0.1973 - val_dmae: 155647.9062 - val_loss: 0.1614 - val_mae: 0.2792 Epoch 17/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 108776.5781 - loss: 0.0916 - mae: 0.1951 - val_dmae: 153423.2344 - val_loss: 0.1576 - val_mae: 0.2752 Epoch 18/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 108486.1719 - loss: 0.0917 - mae: 0.1946 - val_dmae: 154014.7031 - val_loss: 0.1574 - val_mae: 0.2763 Epoch 19/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 107799.8594 - loss: 0.0901 - mae: 0.1934 - val_dmae: 152414.8906 - val_loss: 0.1540 - val_mae: 0.2734 Epoch 20/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 106329.3672 - loss: 0.0884 - mae: 0.1907 - val_dmae: 151545.2188 - val_loss: 0.1508 - val_mae: 0.2719 Epoch 21/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 105966.0391 - loss: 0.0869 - mae: 0.1901 - val_dmae: 149537.7188 - val_loss: 0.1474 - val_mae: 0.2683 Epoch 22/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 104388.3047 - loss: 0.0846 - mae: 0.1873 - val_dmae: 149337.6406 - val_loss: 0.1454 - val_mae: 0.2679 Epoch 23/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 103674.5234 - loss: 0.0836 - mae: 0.1860 - val_dmae: 147303.0625 - val_loss: 0.1416 - val_mae: 0.2642 Epoch 24/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 104319.8047 - loss: 0.0829 - mae: 0.1871 - val_dmae: 145766.2188 - val_loss: 0.1394 - val_mae: 0.2615 Epoch 25/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 
104076.4453 - loss: 0.0818 - mae: 0.1867 - val_dmae: 146061.2500 - val_loss: 0.1388 - val_mae: 0.2620 Epoch 26/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 101932.6719 - loss: 0.0803 - mae: 0.1829 - val_dmae: 144234.1719 - val_loss: 0.1358 - val_mae: 0.2587 Epoch 27/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 102819.5078 - loss: 0.0808 - mae: 0.1844 - val_dmae: 143184.4844 - val_loss: 0.1350 - val_mae: 0.2569 Epoch 28/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 100399.2109 - loss: 0.0767 - mae: 0.1801 - val_dmae: 142360.8281 - val_loss: 0.1328 - val_mae: 0.2554 Epoch 29/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 103223.0391 - loss: 0.0802 - mae: 0.1852 - val_dmae: 139313.1094 - val_loss: 0.1286 - val_mae: 0.2499 Epoch 30/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 102348.4609 - loss: 0.0784 - mae: 0.1836 - val_dmae: 139132.4219 - val_loss: 0.1271 - val_mae: 0.2496 Epoch 31/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 102387.0234 - loss: 0.0779 - mae: 0.1837 - val_dmae: 138350.6875 - val_loss: 0.1258 - val_mae: 0.2482 Epoch 32/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 100219.6094 - loss: 0.0758 - mae: 0.1798 - val_dmae: 138669.4375 - val_loss: 0.1257 - val_mae: 0.2488 Epoch 33/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 100385.4297 - loss: 0.0747 - mae: 0.1801 - val_dmae: 138432.5469 - val_loss: 0.1242 - val_mae: 0.2483 Epoch 34/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 99647.8203 - loss: 0.0745 - mae: 0.1788 - val_dmae: 136882.0156 - val_loss: 0.1215 - val_mae: 0.2456 Epoch 35/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 99347.9141 - loss: 0.0735 - mae: 0.1782 - val_dmae: 135133.5781 - val_loss: 0.1195 - val_mae: 0.2424 Epoch 36/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 98495.4219 - loss: 0.0724 - mae: 0.1767 - val_dmae: 136209.8594 - val_loss: 0.1211 - val_mae: 0.2443 Epoch 37/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 97213.7422 - loss: 0.0705 - mae: 
0.1744 - val_dmae: 134411.0781 - val_loss: 0.1173 - val_mae: 0.2411 Epoch 38/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 48ms/step - dmae: 97048.8359 - loss: 0.0699 - mae: 0.1741 - val_dmae: 133364.7656 - val_loss: 0.1166 - val_mae: 0.2392 Epoch 39/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 97005.8438 - loss: 0.0709 - mae: 0.1740 - val_dmae: 132827.4688 - val_loss: 0.1148 - val_mae: 0.2383 Epoch 40/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 97565.8516 - loss: 0.0707 - mae: 0.1750 - val_dmae: 133914.7031 - val_loss: 0.1168 - val_mae: 0.2402 Epoch 41/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 97036.3672 - loss: 0.0704 - mae: 0.1741 - val_dmae: 133973.3125 - val_loss: 0.1152 - val_mae: 0.2403 Epoch 42/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 96758.5781 - loss: 0.0687 - mae: 0.1736 - val_dmae: 133312.0625 - val_loss: 0.1144 - val_mae: 0.2391 Epoch 43/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - dmae: 98092.6172 - loss: 0.0705 - mae: 0.1760 - val_dmae: 130319.1562 - val_loss: 0.1114 - val_mae: 0.2338 Epoch 44/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 95912.1953 - loss: 0.0677 - mae: 0.1721 - val_dmae: 131547.6094 - val_loss: 0.1137 - val_mae: 0.2360 Epoch 45/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 95580.4609 - loss: 0.0667 - mae: 0.1715 - val_dmae: 131305.7500 - val_loss: 0.1124 - val_mae: 0.2355 Epoch 46/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 96280.3750 - loss: 0.0685 - mae: 0.1727 - val_dmae: 133229.5938 - val_loss: 0.1140 - val_mae: 0.2390 Epoch 47/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 97576.2266 - loss: 0.0692 - mae: 0.1750 - val_dmae: 130656.0078 - val_loss: 0.1115 - val_mae: 0.2344 Epoch 48/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 96363.0391 - loss: 0.0674 - mae: 0.1729 - val_dmae: 130564.7422 - val_loss: 0.1097 - val_mae: 0.2342 Epoch 48: early stopping 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 16ms/step - dmae: 72603.0781 - loss: 0.0325 - mae: 0.1302
[0.027373870834708214, 64881.5390625, 0.11639059334993362]
We see that the results with the LSTM cell (73k) are worse than those with the GRU (64k). In the original LSTM paper (Hochreiter and Schmidhuber, 1997), the authors describe how the LSTM cell is able to contextualize longer sequences thanks to its three gates, which control what information is kept and what is forgotten. In fields where sequences are longer (e.g. NLP, where sentences typically contain 10-20 words), LSTMs considerably outperform GRUs. In this dataset, since the sequence length is fixed to $S=2$, we see no significant difference between these two cells. The GRU retrieves slightly better results than the LSTM, probably because its lower number of parameters makes it less prone to overfitting.
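The parameter gap between the two cells can be sketched with the textbook formulas (ignoring Keras-specific details such as the extra recurrent bias added by the GRU's `reset_after=True` default); the input size `d_x = 7` here is only an illustrative assumption:

```python
def recurrent_params(d_x: int, d_h: int, n_gates: int) -> int:
    """Textbook parameter count of a recurrent cell: each gate owns an
    input kernel (d_x*d_h), a recurrent kernel (d_h*d_h) and a bias (d_h)."""
    return n_gates * (d_x * d_h + d_h * d_h + d_h)

d_x, d_h = 7, 50  # illustrative input size; hidden size matches the notebook
lstm = recurrent_params(d_x, d_h, n_gates=4)  # input, forget, output gates + cell candidate
gru = recurrent_params(d_x, d_h, n_gates=3)   # update, reset gates + candidate state
print(lstm, gru, gru / lstm)  # the GRU carries ~25% fewer parameters per layer
```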
As a further improvement to our architecture, we can enable bidirectional processing in the recurrent layers. Bidirectional processing learns the recurrent contextualization of an input sequence both left-to-right and right-to-left, and concatenates the two hidden contextualizations into a new sequence representation that carries both past and future information. Bidirectionality has repeatedly demonstrated considerable improvements in recurrent architectures, since it allows the network to contextualize current information with future observations.
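The concatenation mechanism can be sketched with a minimal vanilla RNN in NumPy (random weights and toy sizes; the actual layers in the model are Keras GRUs wrapped in `Bidirectional`):

```python
import numpy as np

rng = np.random.default_rng(0)
S, d_x, d_h = 2, 4, 8  # toy sizes; the notebook uses d_h = 50

def run_rnn(X, W, U):
    """Minimal tanh RNN: returns all hidden states, shape (S, d_h)."""
    h = np.zeros(d_h)
    states = []
    for x_t in X:
        h = np.tanh(x_t @ W + h @ U)
        states.append(h)
    return np.stack(states)

X = rng.normal(size=(S, d_x))
W_f, U_f = rng.normal(size=(d_x, d_h)), rng.normal(size=(d_h, d_h))
W_b, U_b = rng.normal(size=(d_x, d_h)), rng.normal(size=(d_h, d_h))

H_fwd = run_rnn(X, W_f, U_f)               # left-to-right pass
H_bwd = run_rnn(X[::-1], W_b, U_b)[::-1]   # right-to-left pass, re-aligned to time order
H_bi = np.concatenate([H_fwd, H_bwd], axis=-1)  # (S, 2*d_h): past + future context
```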
The next cell trains a bidirectional GRU-based encoder with a 2-layer stacked FFN decoder, using a smaller learning rate to smooth the weight optimization.
model_bigru = WalmartModel(2, base_layer=GRU, hidden_size=50, num_encoder_layers=3, dropout=0.1, bidirectional=True)
model_bigru.train(train, val, 'results/walmart3.weights.h5', Adam(5e-4), batch_size=BATCH_SIZE)
model_bigru.evaluate(test)
Epoch 1/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 11s 61ms/step - dmae: 511694.0312 - loss: 1.1551 - mae: 0.9179 - val_dmae: 189024.5156 - val_loss: 0.2542 - val_mae: 0.3391 Epoch 2/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 142497.4375 - loss: 0.1530 - mae: 0.2556 - val_dmae: 186431.4219 - val_loss: 0.2618 - val_mae: 0.3344 Epoch 3/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 108881.1016 - loss: 0.1089 - mae: 0.1953 - val_dmae: 174624.7188 - val_loss: 0.2308 - val_mae: 0.3133 Epoch 4/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 107733.7578 - loss: 0.1084 - mae: 0.1933 - val_dmae: 172770.8125 - val_loss: 0.2240 - val_mae: 0.3099 Epoch 5/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 107267.8594 - loss: 0.1057 - mae: 0.1924 - val_dmae: 171325.6094 - val_loss: 0.2185 - val_mae: 0.3073 Epoch 6/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108637.7891 - loss: 0.1061 - mae: 0.1949 - val_dmae: 169672.3750 - val_loss: 0.2139 - val_mae: 0.3044 Epoch 7/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108563.2109 - loss: 0.1049 - mae: 0.1948 - val_dmae: 168942.3594 - val_loss: 0.2103 - val_mae: 0.3031 Epoch 8/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 109264.6172 - loss: 0.1058 - mae: 0.1960 - val_dmae: 167967.6250 - val_loss: 0.2068 - val_mae: 0.3013 Epoch 9/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 109888.1797 - loss: 0.1050 - mae: 0.1971 - val_dmae: 167338.2500 - val_loss: 0.2035 - val_mae: 0.3002 Epoch 10/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 110322.7031 - loss: 0.1040 - mae: 0.1979 - val_dmae: 166839.9688 - val_loss: 0.2007 - val_mae: 0.2993 Epoch 11/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 111002.6016 - loss: 0.1040 - mae: 0.1991 - val_dmae: 166145.4375 - val_loss: 0.1969 - val_mae: 0.2980 Epoch 12/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 111035.5859 - loss: 0.1032 - mae: 0.1992 - val_dmae: 165776.1094 - val_loss: 0.1946 - val_mae: 0.2974 Epoch 13/2000 45/45 
━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 111394.9844 - loss: 0.1029 - mae: 0.1998 - val_dmae: 165218.1719 - val_loss: 0.1916 - val_mae: 0.2964 Epoch 14/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 111347.3828 - loss: 0.1022 - mae: 0.1997 - val_dmae: 164132.3594 - val_loss: 0.1886 - val_mae: 0.2944 Epoch 15/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110852.0625 - loss: 0.1001 - mae: 0.1989 - val_dmae: 164826.7031 - val_loss: 0.1868 - val_mae: 0.2957 Epoch 16/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 111642.8672 - loss: 0.1013 - mae: 0.2003 - val_dmae: 163517.5781 - val_loss: 0.1834 - val_mae: 0.2933 Epoch 17/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110185.2578 - loss: 0.0988 - mae: 0.1977 - val_dmae: 162725.9375 - val_loss: 0.1807 - val_mae: 0.2919 Epoch 18/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110214.9531 - loss: 0.0973 - mae: 0.1977 - val_dmae: 162752.4219 - val_loss: 0.1786 - val_mae: 0.2920 Epoch 19/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 110505.2734 - loss: 0.0977 - mae: 0.1982 - val_dmae: 161590.3594 - val_loss: 0.1752 - val_mae: 0.2899 Epoch 20/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 110159.6094 - loss: 0.0953 - mae: 0.1976 - val_dmae: 160823.5000 - val_loss: 0.1723 - val_mae: 0.2885 Epoch 21/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 109186.1094 - loss: 0.0936 - mae: 0.1959 - val_dmae: 159566.3281 - val_loss: 0.1683 - val_mae: 0.2862 Epoch 22/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 108622.1641 - loss: 0.0918 - mae: 0.1949 - val_dmae: 158404.7500 - val_loss: 0.1638 - val_mae: 0.2842 Epoch 23/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108196.7969 - loss: 0.0905 - mae: 0.1941 - val_dmae: 158937.7656 - val_loss: 0.1609 - val_mae: 0.2851 Epoch 24/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 107088.9844 - loss: 0.0880 - mae: 0.1921 - val_dmae: 158154.9062 - val_loss: 0.1579 - val_mae: 0.2837 Epoch 25/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 
33ms/step - dmae: 105292.7422 - loss: 0.0851 - mae: 0.1889 - val_dmae: 157132.9062 - val_loss: 0.1550 - val_mae: 0.2819 Epoch 26/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 105309.1875 - loss: 0.0841 - mae: 0.1889 - val_dmae: 156571.4375 - val_loss: 0.1535 - val_mae: 0.2809 Epoch 27/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 103167.1172 - loss: 0.0806 - mae: 0.1851 - val_dmae: 151948.5938 - val_loss: 0.1486 - val_mae: 0.2726 Epoch 28/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 100862.3594 - loss: 0.0779 - mae: 0.1809 - val_dmae: 152059.4062 - val_loss: 0.1477 - val_mae: 0.2728 Epoch 29/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 100803.9766 - loss: 0.0778 - mae: 0.1808 - val_dmae: 150152.9688 - val_loss: 0.1442 - val_mae: 0.2694 Epoch 30/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 98984.7422 - loss: 0.0757 - mae: 0.1776 - val_dmae: 148794.4375 - val_loss: 0.1423 - val_mae: 0.2669 Epoch 31/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 97321.9062 - loss: 0.0732 - mae: 0.1746 - val_dmae: 148345.7031 - val_loss: 0.1397 - val_mae: 0.2661 Epoch 32/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 97453.7578 - loss: 0.0725 - mae: 0.1748 - val_dmae: 146179.5938 - val_loss: 0.1370 - val_mae: 0.2622 Epoch 33/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 96209.7656 - loss: 0.0715 - mae: 0.1726 - val_dmae: 145746.8906 - val_loss: 0.1345 - val_mae: 0.2615 Epoch 34/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 96604.7344 - loss: 0.0714 - mae: 0.1733 - val_dmae: 143477.1875 - val_loss: 0.1311 - val_mae: 0.2574 Epoch 35/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 96501.2734 - loss: 0.0699 - mae: 0.1731 - val_dmae: 143601.0781 - val_loss: 0.1310 - val_mae: 0.2576 Epoch 36/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 95101.6719 - loss: 0.0683 - mae: 0.1706 - val_dmae: 142305.2656 - val_loss: 0.1277 - val_mae: 0.2553 Epoch 37/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 95438.9453 - loss: 
0.0676 - mae: 0.1712 - val_dmae: 142679.8125 - val_loss: 0.1265 - val_mae: 0.2560 Epoch 38/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 94747.6562 - loss: 0.0663 - mae: 0.1700 - val_dmae: 137811.0156 - val_loss: 0.1201 - val_mae: 0.2472 Epoch 39/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 93081.6562 - loss: 0.0651 - mae: 0.1670 - val_dmae: 139094.8281 - val_loss: 0.1209 - val_mae: 0.2495 Epoch 40/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 94390.5391 - loss: 0.0659 - mae: 0.1693 - val_dmae: 135983.8594 - val_loss: 0.1169 - val_mae: 0.2439 Epoch 41/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 92927.3203 - loss: 0.0644 - mae: 0.1667 - val_dmae: 140755.0312 - val_loss: 0.1227 - val_mae: 0.2525 Epoch 42/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 94259.3516 - loss: 0.0654 - mae: 0.1691 - val_dmae: 133848.7500 - val_loss: 0.1140 - val_mae: 0.2401 Epoch 43/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 93657.9922 - loss: 0.0650 - mae: 0.1680 - val_dmae: 137869.3594 - val_loss: 0.1192 - val_mae: 0.2473 Epoch 44/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 92605.9922 - loss: 0.0634 - mae: 0.1661 - val_dmae: 134989.3438 - val_loss: 0.1148 - val_mae: 0.2422 Epoch 45/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 92207.6094 - loss: 0.0634 - mae: 0.1654 - val_dmae: 135320.2812 - val_loss: 0.1154 - val_mae: 0.2428 Epoch 46/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 93152.2578 - loss: 0.0638 - mae: 0.1671 - val_dmae: 130911.6562 - val_loss: 0.1084 - val_mae: 0.2348 Epoch 47/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 92816.6250 - loss: 0.0637 - mae: 0.1665 - val_dmae: 131079.3750 - val_loss: 0.1083 - val_mae: 0.2351 Epoch 48/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 91563.8828 - loss: 0.0619 - mae: 0.1643 - val_dmae: 135479.6406 - val_loss: 0.1136 - val_mae: 0.2430 Epoch 49/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 91637.3594 - loss: 0.0623 - mae: 0.1644 - val_dmae: 
135391.9844 - val_loss: 0.1133 - val_mae: 0.2429 Epoch 50/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 91618.4531 - loss: 0.0623 - mae: 0.1644 - val_dmae: 139915.0625 - val_loss: 0.1206 - val_mae: 0.2510 Epoch 51/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 92722.4375 - loss: 0.0629 - mae: 0.1663 - val_dmae: 134980.9062 - val_loss: 0.1145 - val_mae: 0.2421 Epoch 51: early stopping 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 15ms/step - dmae: 69979.5078 - loss: 0.0293 - mae: 0.1255
[0.02632186934351921, 64908.85546875, 0.1164395734667778]
The bidirectional GRU achieves a better MAE than the LSTM-based model but does not outperform the unidirectional GRU. The main reason why bidirectionality does not improve the baseline results is, again, (1) the short sequence length of the data stream and (2) the nature of the information collected in our dataset. Bidirectionality usually works well with longer sequences where future information is needed to give meaning to the whole sequence (e.g. in natural language). Here the sequences are short ($S=2$) and left-to-right processing makes more sense, since in a real environment the data stream is also generated from left to right and the target is always a future outcome of the previous input observations.
As a final improvement to our network, and having established that the unidirectional GRU cell is the best choice for the base recurrent layer, we increase the sequence length to $S=3$ to see if the architecture can be further improved:
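The effect of a longer window can be illustrated with a generic sliding-window constructor (the real indexing is handled inside `WalmartDataset`; this is only a sketch where `horizon=2` mirrors the $t+2$ target described at the top of the notebook):

```python
import numpy as np

def make_windows(series, S, horizon=2):
    """Build (window, target) pairs: S consecutive timesteps predict the
    value `horizon` steps after the last input timestep."""
    X, y = [], []
    for t in range(len(series) - S - horizon + 1):
        X.append(series[t:t + S])
        y.append(series[t + S + horizon - 1])
    return np.array(X), np.array(y)

# A longer window S gives each sample more history, at the cost of
# slightly fewer samples from the same series.
X3, y3 = make_windows(np.arange(10), S=3)
X2, y2 = make_windows(np.arange(10), S=2)
```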
model = WalmartModel(3, base_layer=GRU, hidden_size=50, num_encoder_layers=3, activation='relu')
model.train(train, val, 'results/walmart3.weights.h5', Adam(1e-4), batch_size=BATCH_SIZE)
model.evaluate(test)
Epoch 1/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 6s 48ms/step - dmae: 540893.8750 - loss: 1.2754 - mae: 0.9703 - val_dmae: 449361.0312 - val_loss: 0.9042 - val_mae: 0.8061 Epoch 2/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 538479.1875 - loss: 1.2653 - mae: 0.9660 - val_dmae: 447091.1875 - val_loss: 0.8951 - val_mae: 0.8020 Epoch 3/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 535209.2500 - loss: 1.2513 - mae: 0.9601 - val_dmae: 443291.3125 - val_loss: 0.8800 - val_mae: 0.7952 Epoch 4/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 529381.6875 - loss: 1.2268 - mae: 0.9497 - val_dmae: 436300.8125 - val_loss: 0.8528 - val_mae: 0.7827 Epoch 5/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 518687.8125 - loss: 1.1826 - mae: 0.9305 - val_dmae: 423372.2500 - val_loss: 0.8035 - val_mae: 0.7595 Epoch 6/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 499458.9688 - loss: 1.1044 - mae: 0.8960 - val_dmae: 401133.8438 - val_loss: 0.7227 - val_mae: 0.7196 Epoch 7/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 466482.4062 - loss: 0.9773 - mae: 0.8368 - val_dmae: 364942.8438 - val_loss: 0.6002 - val_mae: 0.6547 Epoch 8/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 411903.0312 - loss: 0.7847 - mae: 0.7389 - val_dmae: 308532.4062 - val_loss: 0.4345 - val_mae: 0.5535 Epoch 9/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 328275.8438 - loss: 0.5317 - mae: 0.5889 - val_dmae: 236963.0156 - val_loss: 0.2656 - val_mae: 0.4251 Epoch 10/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 224794.2812 - loss: 0.2873 - mae: 0.4033 - val_dmae: 190205.9688 - val_loss: 0.1850 - val_mae: 0.3412 Epoch 11/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 149743.4062 - loss: 0.1608 - mae: 0.2686 - val_dmae: 178606.9844 - val_loss: 0.1855 - val_mae: 0.3204 Epoch 12/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 132380.6406 - loss: 0.1441 - mae: 0.2375 - val_dmae: 174568.6406 - val_loss: 0.1837 - val_mae: 0.3132 Epoch 13/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 
1s 32ms/step - dmae: 128074.0703 - loss: 0.1378 - mae: 0.2298 - val_dmae: 171316.6250 - val_loss: 0.1807 - val_mae: 0.3073 Epoch 14/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 123321.8047 - loss: 0.1309 - mae: 0.2212 - val_dmae: 168996.5781 - val_loss: 0.1795 - val_mae: 0.3032 Epoch 15/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 121276.4062 - loss: 0.1307 - mae: 0.2176 - val_dmae: 167217.9844 - val_loss: 0.1792 - val_mae: 0.3000 Epoch 16/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 119898.6719 - loss: 0.1266 - mae: 0.2151 - val_dmae: 165710.2500 - val_loss: 0.1779 - val_mae: 0.2973 Epoch 17/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 120748.2500 - loss: 0.1311 - mae: 0.2166 - val_dmae: 164155.7812 - val_loss: 0.1751 - val_mae: 0.2945 Epoch 18/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118620.0938 - loss: 0.1283 - mae: 0.2128 - val_dmae: 163283.8750 - val_loss: 0.1744 - val_mae: 0.2929 Epoch 19/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 119612.2266 - loss: 0.1292 - mae: 0.2146 - val_dmae: 162294.1719 - val_loss: 0.1721 - val_mae: 0.2911 Epoch 20/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 117537.4141 - loss: 0.1259 - mae: 0.2108 - val_dmae: 161565.1562 - val_loss: 0.1715 - val_mae: 0.2898 Epoch 21/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118515.7188 - loss: 0.1237 - mae: 0.2126 - val_dmae: 161125.4375 - val_loss: 0.1707 - val_mae: 0.2890 Epoch 22/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116537.5859 - loss: 0.1236 - mae: 0.2091 - val_dmae: 160738.9531 - val_loss: 0.1703 - val_mae: 0.2883 Epoch 23/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 115967.6094 - loss: 0.1221 - mae: 0.2080 - val_dmae: 159929.4531 - val_loss: 0.1678 - val_mae: 0.2869 Epoch 24/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116750.2656 - loss: 0.1233 - mae: 0.2094 - val_dmae: 159184.4531 - val_loss: 0.1662 - val_mae: 0.2856 Epoch 25/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 
114966.4922 - loss: 0.1223 - mae: 0.2062 - val_dmae: 158413.7344 - val_loss: 0.1646 - val_mae: 0.2842 Epoch 26/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 114064.5000 - loss: 0.1196 - mae: 0.2046 - val_dmae: 157375.3125 - val_loss: 0.1621 - val_mae: 0.2823 Epoch 27/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 114311.6562 - loss: 0.1190 - mae: 0.2051 - val_dmae: 156778.7031 - val_loss: 0.1611 - val_mae: 0.2812 Epoch 28/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 114478.1328 - loss: 0.1203 - mae: 0.2054 - val_dmae: 155944.7656 - val_loss: 0.1593 - val_mae: 0.2797 Epoch 29/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113959.6016 - loss: 0.1200 - mae: 0.2044 - val_dmae: 154698.1406 - val_loss: 0.1567 - val_mae: 0.2775 Epoch 30/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112983.5547 - loss: 0.1161 - mae: 0.2027 - val_dmae: 154906.0625 - val_loss: 0.1580 - val_mae: 0.2779 Epoch 31/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 113660.1641 - loss: 0.1153 - mae: 0.2039 - val_dmae: 153222.2500 - val_loss: 0.1539 - val_mae: 0.2749 Epoch 32/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112292.4141 - loss: 0.1149 - mae: 0.2014 - val_dmae: 153329.9375 - val_loss: 0.1551 - val_mae: 0.2751 Epoch 33/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111415.6875 - loss: 0.1130 - mae: 0.1999 - val_dmae: 151707.0312 - val_loss: 0.1515 - val_mae: 0.2721 Epoch 34/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 111281.8516 - loss: 0.1122 - mae: 0.1996 - val_dmae: 151017.2188 - val_loss: 0.1500 - val_mae: 0.2709 Epoch 35/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 112849.5625 - loss: 0.1135 - mae: 0.2024 - val_dmae: 149998.5000 - val_loss: 0.1486 - val_mae: 0.2691 Epoch 36/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 110330.5859 - loss: 0.1082 - mae: 0.1979 - val_dmae: 149588.4531 - val_loss: 0.1481 - val_mae: 0.2683 Epoch 37/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 110554.2500 - loss: 0.1129 - 
mae: 0.1983 - val_dmae: 149249.8750 - val_loss: 0.1482 - val_mae: 0.2677 Epoch 38/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110144.3750 - loss: 0.1089 - mae: 0.1976 - val_dmae: 147686.2344 - val_loss: 0.1454 - val_mae: 0.2649 Epoch 39/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110666.6172 - loss: 0.1090 - mae: 0.1985 - val_dmae: 146978.9062 - val_loss: 0.1445 - val_mae: 0.2637 Epoch 40/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109784.4375 - loss: 0.1084 - mae: 0.1969 - val_dmae: 146587.7031 - val_loss: 0.1440 - val_mae: 0.2630 Epoch 41/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109904.5234 - loss: 0.1099 - mae: 0.1972 - val_dmae: 145439.9844 - val_loss: 0.1421 - val_mae: 0.2609 Epoch 42/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110131.3125 - loss: 0.1081 - mae: 0.1976 - val_dmae: 145447.8438 - val_loss: 0.1427 - val_mae: 0.2609 Epoch 43/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 109797.9609 - loss: 0.1075 - mae: 0.1970 - val_dmae: 144671.6094 - val_loss: 0.1416 - val_mae: 0.2595 Epoch 44/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108804.9453 - loss: 0.1065 - mae: 0.1952 - val_dmae: 144289.3281 - val_loss: 0.1413 - val_mae: 0.2588 Epoch 45/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108217.3594 - loss: 0.1064 - mae: 0.1941 - val_dmae: 143205.9688 - val_loss: 0.1390 - val_mae: 0.2569 Epoch 46/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108837.2109 - loss: 0.1069 - mae: 0.1952 - val_dmae: 144665.3281 - val_loss: 0.1422 - val_mae: 0.2595 Epoch 47/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107000.0391 - loss: 0.1057 - mae: 0.1919 - val_dmae: 142787.6250 - val_loss: 0.1388 - val_mae: 0.2561 Epoch 48/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108004.4297 - loss: 0.1049 - mae: 0.1937 - val_dmae: 142663.8125 - val_loss: 0.1387 - val_mae: 0.2559 Epoch 49/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 108662.1172 - loss: 0.1064 - mae: 0.1949 - val_dmae: 
142986.2656 - val_loss: 0.1394 - val_mae: 0.2565 Epoch 50/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106500.9219 - loss: 0.1055 - mae: 0.1911 - val_dmae: 142966.2031 - val_loss: 0.1394 - val_mae: 0.2565 Epoch 51/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108066.3672 - loss: 0.1040 - mae: 0.1939 - val_dmae: 141710.1406 - val_loss: 0.1377 - val_mae: 0.2542 Epoch 52/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108070.0625 - loss: 0.1044 - mae: 0.1939 - val_dmae: 142136.1719 - val_loss: 0.1380 - val_mae: 0.2550 Epoch 53/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108798.0391 - loss: 0.1051 - mae: 0.1952 - val_dmae: 141703.1719 - val_loss: 0.1376 - val_mae: 0.2542 Epoch 54/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108090.4453 - loss: 0.1038 - mae: 0.1939 - val_dmae: 142201.6094 - val_loss: 0.1383 - val_mae: 0.2551 Epoch 55/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 107303.2500 - loss: 0.1035 - mae: 0.1925 - val_dmae: 141007.3281 - val_loss: 0.1370 - val_mae: 0.2530 Epoch 56/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106790.6719 - loss: 0.1030 - mae: 0.1916 - val_dmae: 140742.4375 - val_loss: 0.1356 - val_mae: 0.2525 Epoch 57/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108648.7656 - loss: 0.1031 - mae: 0.1949 - val_dmae: 140455.1562 - val_loss: 0.1362 - val_mae: 0.2520 Epoch 58/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105881.9219 - loss: 0.1018 - mae: 0.1899 - val_dmae: 139839.9688 - val_loss: 0.1347 - val_mae: 0.2509 Epoch 59/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107379.3359 - loss: 0.1027 - mae: 0.1926 - val_dmae: 140462.4531 - val_loss: 0.1358 - val_mae: 0.2520 Epoch 60/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106281.9766 - loss: 0.1009 - mae: 0.1907 - val_dmae: 140288.6875 - val_loss: 0.1358 - val_mae: 0.2517 Epoch 61/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107474.5000 - loss: 0.1025 - mae: 0.1928 - val_dmae: 139368.0156 - val_loss: 0.1339 
- val_mae: 0.2500 Epoch 62/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105230.5703 - loss: 0.1008 - mae: 0.1888 - val_dmae: 140052.0312 - val_loss: 0.1355 - val_mae: 0.2512 Epoch 63/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107418.7188 - loss: 0.1024 - mae: 0.1927 - val_dmae: 139620.7500 - val_loss: 0.1344 - val_mae: 0.2505 Epoch 64/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106750.2500 - loss: 0.1021 - mae: 0.1915 - val_dmae: 139406.6094 - val_loss: 0.1340 - val_mae: 0.2501 Epoch 65/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105582.7500 - loss: 0.1015 - mae: 0.1894 - val_dmae: 139395.6094 - val_loss: 0.1343 - val_mae: 0.2501 Epoch 66/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105239.0000 - loss: 0.0998 - mae: 0.1888 - val_dmae: 138382.9062 - val_loss: 0.1332 - val_mae: 0.2482 Epoch 67/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106342.9844 - loss: 0.1037 - mae: 0.1908 - val_dmae: 138564.4375 - val_loss: 0.1330 - val_mae: 0.2486 Epoch 68/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105170.2578 - loss: 0.1006 - mae: 0.1887 - val_dmae: 137998.2656 - val_loss: 0.1322 - val_mae: 0.2476 Epoch 69/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104818.1641 - loss: 0.0981 - mae: 0.1880 - val_dmae: 139308.3438 - val_loss: 0.1347 - val_mae: 0.2499 Epoch 70/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106699.0078 - loss: 0.1004 - mae: 0.1914 - val_dmae: 139299.6562 - val_loss: 0.1341 - val_mae: 0.2499 Epoch 71/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105436.6016 - loss: 0.0982 - mae: 0.1891 - val_dmae: 138671.7500 - val_loss: 0.1336 - val_mae: 0.2488 Epoch 72/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106009.1016 - loss: 0.0989 - mae: 0.1902 - val_dmae: 139040.2031 - val_loss: 0.1338 - val_mae: 0.2494 Epoch 73/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 107466.4688 - loss: 0.1009 - mae: 0.1928 - val_dmae: 137619.5469 - val_loss: 0.1307 - val_mae: 0.2469 Epoch 74/2000 
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107295.4922 - loss: 0.1017 - mae: 0.1925 - val_dmae: 138153.8125 - val_loss: 0.1321 - val_mae: 0.2478 Epoch 75/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103870.9141 - loss: 0.0978 - mae: 0.1863 - val_dmae: 139248.2969 - val_loss: 0.1338 - val_mae: 0.2498 Epoch 76/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105554.5938 - loss: 0.1001 - mae: 0.1894 - val_dmae: 138325.0938 - val_loss: 0.1323 - val_mae: 0.2481 Epoch 77/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105591.8828 - loss: 0.0984 - mae: 0.1894 - val_dmae: 139266.5000 - val_loss: 0.1342 - val_mae: 0.2498 Epoch 78/2000 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105442.5234 - loss: 0.0995 - mae: 0.1892 - val_dmae: 137623.3750 - val_loss: 0.1314 - val_mae: 0.2469 Epoch 78: early stopping 45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 15ms/step - dmae: 63942.0352 - loss: 0.0242 - mae: 0.1147
[0.02329966053366661, 61717.328125, 0.11071430891752243]
plot_series(model, [train, val, test], title='Predictions with S=3').show()
By increasing the $S$ value we obtain our best MAE result so far (61k).
Regularization hyperparameters¶
Other hyperparameters of the network are less interpretable than the configurations explored above (bidirectionality, GRU vs. LSTM, dropout, model dimension, etc.). For those hyperparameters we ran a grid search to find the best configuration.
grid = OrderedDict(
    regularizer=[L1(1e-3), L2(1e-3), L1L2(1e-3)],
    initializer=['random_normal', 'glorot_uniform'],
    activation=['tanh', 'relu']
)
def applydeep(lists, func):
    """Applies func to every element of a list of lists, preserving the nesting."""
    return [list(map(func, item)) for item in lists]
df = pd.DataFrame(columns=['train', 'val', 'test'],
                  index=pd.MultiIndex.from_product(applydeep(grid.values(), str)))
for params in product(*grid.values()):
    params = dict(zip(grid.keys(), params))
    model = WalmartModel(seq_len=3, base_layer=GRU, num_encoder_layers=3, num_decoder_layers=2,
                         bidirectional=False, dropout=0.1, **params)
    model.train(train, val, 'results/walmart.weights.h5', Adam(1e-4), batch_size=BATCH_SIZE)
    (_, train_mae, _), (_, val_mae, _), (_, test_mae, _) = map(model.evaluate, (train, val, test))
    df.loc[tuple(map(str, params.values()))] = [train_mae, val_mae, test_mae]
    df.to_csv('grid.csv')  # checkpoint the partial results after each configuration
df.index.names = grid.keys()
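Once the grid finishes, the best configuration can be read off the table by minimizing the validation column. A toy stand-in for `df` (the MAE values below are made up for illustration) shows the idea:

```python
import pandas as pd

# Toy stand-in for the grid-search table built above (values are made up)
idx = pd.MultiIndex.from_product(
    [['L1', 'L2'], ['random_normal', 'glorot_uniform'], ['tanh', 'relu']],
    names=['regularizer', 'initializer', 'activation'])
df = pd.DataFrame({'train': 1.0,
                   'val': [5., 3., 4., 6., 2., 7., 8., 9.],
                   'test': 1.0}, index=idx)

best = df['val'].idxmin()  # MultiIndex tuple of the configuration with the lowest val MAE
print(best)
```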